// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <functional>
#include <memory>
#include <ngraph/log.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/util.hpp>
#include <numeric>
#include <openvino/core/validation_util.hpp>
#include <openvino/op/util/pad_base.hpp>
#include <openvino/opsets/opset3.hpp>
#include <openvino/opsets/opset7.hpp>
#include <openvino/opsets/opset8.hpp>
#include <openvino/opsets/opset9.hpp>
#include <openvino/pass/pattern/op/or.hpp>
#include <transformations/common_optimizations/nop_elimination.hpp>
#include <transformations/utils/utils.hpp>

#include "compare.hpp"
#include "itt.hpp"
#include "openvino/util/log.hpp"

using namespace std;
using namespace ov;

// `simplify_gather` optimizes a Gather whose result can be produced without gathering:
//  * case_1 / case_2: the Gather selects the whole input tensor, so its output is bypassed;
//  * case_3: the Gather only drops a size-1 axis, so it is replaced by an equivalent Squeeze.
// Only static data and indices shapes are handled. Returns true if the graph was modified.
static bool simplify_gather(shared_ptr<Node> node) {
    auto gather = ov::as_type_ptr<op::util::GatherBase>(node);
    if (!gather) {
        return false;
    }

    // check if we are gathering the whole input
    auto data = gather->input_value(0);
    auto indices = gather->input_value(1);

    // we need to know data and indices shape to infer if gather is Nop
    if (data.get_partial_shape().is_dynamic() || indices.get_partial_shape().is_dynamic()) {
        return false;
    }

    auto axis = gather->get_axis();
    if (axis == opset3::Gather::AXIS_NOT_SET_VALUE) {
        OPENVINO_DEBUG << "axis value not set";
        return false;
    }

    const auto data_shape = data.get_shape();
    const auto out_shape = node->get_shape();
    const auto data_rank = static_cast<int64_t>(data_shape.size());
    // Normalize a negative axis and make sure it is a valid index into data_shape.
    if (axis < 0) {
        axis += data_rank;
    }
    if (axis < 0 || axis >= data_rank) {
        return false;
    }

    if (data_shape.size() != out_shape.size()) {
        // case_3: if input_shape is (1,3,5,5), axis = 0 and indices is the scalar 0, then gather is just a Squeeze.
        // This only holds when exactly the gathered axis disappears (output rank == data rank - 1). A
        // single-element non-scalar indices tensor, e.g. of shape (1, 1), holds one 0 as well but adds
        // dimensions to the output instead, so it must not be turned into a Squeeze.
        if (out_shape.size() + 1 != data_shape.size()) {
            return false;
        }
        auto constant_indices = ov::as_type_ptr<opset3::Constant>(indices.get_node_shared_ptr());
        if (!constant_indices)
            return false;
        const auto const_indices = constant_indices->cast_vector<int64_t>();
        if (data_shape[axis] == 1 && const_indices.size() == 1 && const_indices[0] == 0) {
            auto squeeze = std::make_shared<opset8::Squeeze>(gather->input_value(0), gather->input_value(2));
            squeeze->set_friendly_name(gather->get_friendly_name());
            ov::copy_runtime_info(gather, squeeze);
            ov::replace_node(gather, squeeze);
            return true;
        }
        return false;
    }

    // case_1 : if the input tensor is of shape (4, 1, 4)
    // and axis = 1, then the gather would be simply
    // gathering the whole input tensor, so we can optimize this
    // op has Nop

    if (data_shape[axis] == 1 && data_shape == out_shape) {
        return replace_output_update_name(gather->output(0), gather->input_value(0));
    }

    // case_2 : if the input tensor is of shape (4, 3, 4)
    // we need to check the contents of indices, if indices
    // is 1D tensor of value {0, 1, 2}, we can optimize this
    // op has Nop

    // check if the indices is constant
    auto constant_indices = ov::as_type_ptr<opset3::Constant>(indices.get_node_shared_ptr());
    if (!constant_indices) {
        return false;
    }
    // if ref_indices == indices, we are capturing the entire input tensor
    vector<int64_t> ref_indices(data_shape[axis], 0);
    iota(ref_indices.begin(), ref_indices.end(), 0);
    if (ref_indices == constant_indices->cast_vector<int64_t>()) {
        return replace_output_update_name(gather->output(0), gather->input_value(0));
    }
    return false;
}

// Bypasses `node` when its first input and its output have the same fully static shape.
// Nodes with any dynamic dimension on input 0 or output 0 are left untouched.
// Returns true if the output was rewired to the input.
static bool eliminate_nop(const shared_ptr<Node>& node) {
    const bool shapes_known =
        node->get_input_partial_shape(0).is_static() && node->get_output_partial_shape(0).is_static();
    if (!shapes_known) {
        return false;
    }

    const bool is_identity = node->get_input_shape(0) == node->get_output_shape(0);
    return is_identity && replace_output_update_name(node->output(0), node->input_value(0));
}

static bool eliminate_reshape_v1(const shared_ptr<Node>& node) {
    auto input = node->input_value(0);

    if (input.get_partial_shape().rank().is_static() && input.get_partial_shape().rank().same_scheme(1)) {
        if (input.get_partial_shape().same_scheme(node->get_output_partial_shape(0)))
            return replace_output_update_name(node->output(0), input);
    }

    // check if reshape is not identity op
    if (input.get_partial_shape().is_dynamic() || node->get_output_partial_shape(0).is_dynamic()) {
        OPENVINO_DEBUG << node << " has dynamic shapes.";
        return false;
    }
    // remove identity op
    if (input.get_shape() == node->get_output_shape(0)) {
        return replace_output_update_name(node->output(0), input);
    }
    // eliminate redundant reshape, squeeze, or unsqueeze
    auto input_node = input.get_node_shared_ptr();
    if (ov::as_type_ptr<opset3::Squeeze>(input_node) || ov::as_type_ptr<opset3::Unsqueeze>(input_node) ||
        ov::as_type_ptr<opset3::Reshape>(input_node)) {
        if (input_node->get_output_target_inputs(0).size() != 1)
            return false;

        auto shape = node->get_output_shape(0);

        // remove interchangeable nodes
        if (input_node->get_input_partial_shape(0).is_static() && input_node->get_input_shape(0) == shape) {
            return replace_output_update_name(node->output(0), input_node->input_value(0));
        } else {
            vector<int64_t> vi;
            vi.assign(shape.begin(), shape.end());
            auto pat = opset3::Constant::create<int64_t>(element::i64, Shape{vi.size()}, vi);
            auto new_reshape = make_shared<opset3::Reshape>(input.get_node()->input_value(0), pat, false);
            new_reshape->set_friendly_name(node->get_friendly_name());
            copy_runtime_info({input_node, node}, new_reshape);
            replace_node(node, new_reshape);
            return true;
        }
    }

    return false;
}

static size_t count_unknown_dims(const PartialShape& ps) {
    size_t rc = 0;
    if (ps.is_static()) {
        return rc;
    }
    for (auto i = 0; i < ps.rank().get_length(); i++) {
        if (ps[i].is_dynamic()) {
            rc += 1;
        }
    }
    return rc;
}

static bool replace_squeeze_unsqueeze(const shared_ptr<Node>& node) {
    auto shape_ps = node->get_output_partial_shape(0);
    if (shape_ps.rank().get_length() == 0) {
        return false;
    }
    if (count_unknown_dims(shape_ps) > 1) {
        return false;
    }
    vector<int64_t> target_shape;
    for (auto i = 0; i < shape_ps.rank().get_length(); i++) {
        if (shape_ps[i].is_dynamic()) {
            target_shape.emplace_back(-1);
        } else {
            target_shape.emplace_back(shape_ps[i].get_length());
        }
    }

    shared_ptr<Node> reshape;
    auto input = node->input_value(0).get_node_shared_ptr();
    auto pat = opset3::Constant::create<int64_t>(element::i64, Shape{target_shape.size()}, target_shape);

    if (ov::is_type<opset3::Reshape>(input) || ov::is_type<opset3::Squeeze>(input) ||
        ov::is_type<opset3::Unsqueeze>(input)) {
        reshape = make_shared<opset3::Reshape>(input->input_value(0), pat, false);
    } else {
        reshape = make_shared<opset3::Reshape>(node->input_value(0), pat, false);
    }

    // skip if reshape is nop
    if (reshape->get_input_partial_shape(0).same_scheme(shape_ps)) {
        copy_runtime_info({input, node->output(0).get_node_shared_ptr()}, node->output(0).get_node_shared_ptr());
        return replace_output_update_name(node->output(0), reshape->input_value(0));
    } else {
        return replace_node_update_name(node, reshape);
    }
}

// Computes candidate Unsqueeze axes that turn `data_shape` into `out_shape`.
// Output dimensions are matched greedily, left to right, against the data dimensions; every unmatched
// output dimension that is statically 1 is reported as an inserted axis. The caller verifies that the
// number of axes accounts for the whole rank difference. Both shapes must have static rank.
static vector<int64_t> get_unsqueeze_axes(const PartialShape& data_shape, const PartialShape& out_shape) {
    vector<int64_t> inserted;
    int64_t next_data = 0;
    for (int64_t pos = 0; pos < out_shape.rank().get_length(); ++pos) {
        const auto& dim = out_shape[pos];
        const bool matches_data =
            next_data < data_shape.rank().get_length() && data_shape[next_data].same_scheme(dim);
        if (matches_data) {
            ++next_data;
        } else if (dim.is_static() && dim == 1) {
            inserted.push_back(pos);
        }
    }
    return inserted;
}

// Computes candidate Squeeze axes that turn `data_shape` into `out_shape`.
// Data dimensions are matched greedily, left to right, against the output dimensions; every unmatched
// data dimension that is statically 1 is reported as a removed axis. The caller verifies that the
// number of axes accounts for the whole rank difference. Both shapes must have static rank.
static vector<int64_t> get_squeeze_axes(const PartialShape& data_shape, const PartialShape& out_shape) {
    vector<int64_t> removed;
    int64_t next_out = 0;
    for (int64_t pos = 0; pos < data_shape.rank().get_length(); ++pos) {
        const auto& dim = data_shape[pos];
        const bool matches_output =
            next_out < out_shape.rank().get_length() && dim.same_scheme(out_shape[next_out]);
        if (matches_output) {
            ++next_out;
        } else if (dim.is_static() && dim == 1) {
            removed.push_back(pos);
        }
    }
    return removed;
}

// Simplifies an opset3::Unsqueeze node:
//  * static, non-scalar output rank with at most one dynamic dimension -> replaced by a Reshape
//    (see replace_squeeze_unsqueeze);
//  * Squeeze->Unsqueeze or Unsqueeze->Unsqueeze pairs -> folded into a single Squeeze/Unsqueeze when the
//    folded op reproduces the same output shape.
// Returns true if the graph was modified.
static bool eliminate_unsqueeze(const shared_ptr<Node>& node) {
    auto out_shape = node->get_output_partial_shape(0);
    // try to replace all squeeze/unsqueeze with reshape
    if (out_shape.rank().is_static() && out_shape.rank().get_length() != 0 && count_unknown_dims(out_shape) < 2) {
        return replace_squeeze_unsqueeze(node);
    }

    auto unsqueeze = ov::as_type_ptr<opset3::Unsqueeze>(node);
    if (unsqueeze == nullptr)
        return false;
    auto input = unsqueeze->input_value(0).get_node_shared_ptr();
    auto squeeze = ov::as_type_ptr<opset3::Squeeze>(input);
    // Replaces `unsqueeze` with one Unsqueeze applied to `input`'s data input, if the output shape is kept.
    auto replace_unsqueeze_only = [&](const vector<int64_t>& axes) {
        auto axes_const = opset3::Constant::create<int64_t>(element::i64, Shape{axes.size()}, axes);
        auto new_unsq = make_shared<opset3::Unsqueeze>(input->input_value(0), axes_const);
        if (unsqueeze->get_output_partial_shape(0).same_scheme(new_unsq->get_output_partial_shape(0))) {
            return replace_node_update_name(unsqueeze, new_unsq);
        }
        return false;
    };
    // eliminate redundant squeeze->unsqueeze
    if (squeeze) {
        const auto& data_shape = squeeze->input_value(0).get_partial_shape();
        // A Squeeze may be created without an axes input, in which case input 1 is out of range:
        // only compare the axes constants when the Squeeze actually has them.
        if (squeeze->get_input_size() > 1 &&
            ngraph::compare_constants(squeeze->input_value(1).get_node_shared_ptr(),
                                      unsqueeze->input_value(1).get_node_shared_ptr())) {
            return replace_output_update_name(unsqueeze->output(0), squeeze->input_value(0));
        }
        if (data_shape.rank().is_dynamic() || out_shape.rank().is_dynamic()) {
            return false;
        }
        if (out_shape.rank().get_length() > data_shape.rank().get_length()) {
            // check if single unsqueeze can handle this
            auto axes = get_unsqueeze_axes(data_shape, out_shape);
            if (static_cast<int64_t>(axes.size()) + data_shape.rank().get_length() == out_shape.rank().get_length()) {
                return replace_unsqueeze_only(axes);
            }
        }
        if (out_shape.rank().get_length() < data_shape.rank().get_length()) {
            // check if single squeeze can handle this
            auto axes = get_squeeze_axes(data_shape, out_shape);
            if (data_shape.rank().get_length() - static_cast<int64_t>(axes.size()) == out_shape.rank().get_length()) {
                auto axes_const = opset3::Constant::create<int64_t>(element::i64, Shape{axes.size()}, axes);
                auto new_sq = make_shared<opset3::Squeeze>(input->input_value(0), axes_const);
                if (unsqueeze->get_output_partial_shape(0).same_scheme(new_sq->get_output_partial_shape(0))) {
                    return replace_node_update_name(unsqueeze, new_sq);
                }
                return false;
            }
        }
        return false;
    }
    // eliminate redundant unsqueeze->unsqueeze
    auto unsqueeze_i = ov::as_type_ptr<opset3::Unsqueeze>(input);
    if (unsqueeze_i) {
        const auto& data_shape = unsqueeze_i->input_value(0).get_partial_shape();
        if (data_shape.rank().is_dynamic() || out_shape.rank().is_dynamic()) {
            return false;
        }
        auto axes = get_unsqueeze_axes(data_shape, out_shape);
        return replace_unsqueeze_only(axes);
    }

    return false;
}

// Two-level stringification: STR(NAME) lets NAME be macro-expanded first, then ECHO turns it into a
// string literal (used for the RTTI type name below).
#define ECHO(NAME) #NAME
#define STR(NAME)  ECHO(NAME)
// SIMPLE_MATCHER_PASS_DEFINITION(NAME, FUNC, OpTypes...) defines a MatcherPass class NAME that matches any node
// of the listed operation types and returns FUNC(matched_node) from the matcher callback.
#define SIMPLE_MATCHER_PASS_DEFINITION(NAME, FUNC, ...)                                 \
    class NAME : public ov::pass::MatcherPass {                                         \
    public:                                                                             \
        OPENVINO_RTTI(STR(NAME), "0");                                                  \
        NAME() {                                                                        \
            MATCHER_SCOPE(NAME);                                                        \
            auto match_node = ov::pass::pattern::wrap_type<__VA_ARGS__>();              \
            ov::matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {   \
                return FUNC(m.get_match_root());                                        \
            };                                                                          \
            auto m = make_shared<ov::pass::pattern::Matcher>(match_node, matcher_name); \
            register_matcher(m, callback);                                              \
        }                                                                               \
    };

// Matcher passes that each delegate to one of the static helpers defined earlier in this file.
// NopElimination registers them only when use_shape_for_elimination is set.
SIMPLE_MATCHER_PASS_DEFINITION(EliminateReshape, eliminate_reshape_v1, opset3::Reshape);
SIMPLE_MATCHER_PASS_DEFINITION(EliminateUnsqueeze, eliminate_unsqueeze, opset3::Unsqueeze);
SIMPLE_MATCHER_PASS_DEFINITION(EliminateBroadcast, eliminate_nop, op::v1::Broadcast, op::v3::Broadcast);
SIMPLE_MATCHER_PASS_DEFINITION(EliminateGather, simplify_gather, opset3::Gather, opset7::Gather, opset8::Gather);

// Bypasses any Pad (PadBase-derived) node whose pads_begin and pads_end inputs are constant-foldable
// and consist only of zeros.
pass::EliminatePad::EliminatePad() {
    MATCHER_SCOPE(EliminatePad);
    auto pad_root = pattern::wrap_type<op::util::PadBase>();

    matcher_pass_callback callback = [](pattern::Matcher& m) {
        const auto pad = m.get_match_root();

        OPENVINO_SUPPRESS_DEPRECATED_START
        const auto begins = get_constant_from_source(pad->input_value(1));
        const auto ends = get_constant_from_source(pad->input_value(2));
        OPENVINO_SUPPRESS_DEPRECATED_END

        if (begins == nullptr || ends == nullptr) {
            return false;
        }

        const auto all_zero = [](const vector<int64_t>& values) {
            return all_of(values.begin(), values.end(), [](int64_t v) {
                return v == 0;
            });
        };
        const bool pads_nothing = all_zero(begins->cast_vector<int64_t>()) && all_zero(ends->cast_vector<int64_t>());
        return pads_nothing && replace_output_update_name(pad->output(0), pad->input_value(0));
    };

    auto m = make_shared<pattern::Matcher>(pad_root, matcher_name);
    this->register_matcher(m, callback);
}

// Bypasses a Convert whose destination element type equals its input element type.
pass::EliminateConvert::EliminateConvert() {
    MATCHER_SCOPE(EliminateConvert);
    auto convert_root = pattern::wrap_type<opset8::Convert>();

    matcher_pass_callback callback = [](pattern::Matcher& m) {
        const auto convert = ov::as_type_ptr<opset8::Convert>(m.get_match_root());
        const bool is_identity = convert && convert->get_input_element_type(0) == convert->get_element_type();
        return is_identity && replace_output_update_name(convert->output(0), convert->input_value(0));
    };

    auto m = make_shared<pattern::Matcher>(convert_root, matcher_name);
    this->register_matcher(m, callback);
}

pass::EliminateConvertNonZero::EliminateConvertNonZero() {
    MATCHER_SCOPE(EliminateConvertNonZero);
    auto convert_pattern = pattern::wrap_type<opset8::Convert>(pattern::consumers_count(1));
    auto non_zero = pattern::wrap_type<opset8::NonZero>({convert_pattern});

    matcher_pass_callback callback = [=](pattern::Matcher& m) {
        const auto& pattern_map = m.get_pattern_map();
        auto convert = pattern_map.at(convert_pattern);
        // remove convert
        convert->output(0).replace(convert->input_value(0));
        // to make this elimination recursive we register NonZero as a node which will be used to repeat matching
        register_new_node(m.get_match_root());
        return true;
    };

    auto m = make_shared<pattern::Matcher>(non_zero, matcher_name);
    this->register_matcher(m, callback);
}

// Bypasses a Concat that has exactly one input (concatenating a single tensor is an identity).
pass::EliminateConcat::EliminateConcat() {
    MATCHER_SCOPE(EliminateConcat);
    auto concat_pattern = pattern::wrap_type<opset8::Concat>();

    matcher_pass_callback callback = [](pattern::Matcher& m) {
        const auto concat = m.get_match_root();
        if (concat->get_input_size() != 1) {
            return false;
        }
        return replace_output_update_name(concat->output(0), concat->input_value(0));
    };

    auto m = make_shared<pattern::Matcher>(concat_pattern, matcher_name);
    this->register_matcher(m, callback);
}

// Bypasses a Split configured to produce a single chunk (num_splits == 1).
pass::EliminateSplit::EliminateSplit() {
    MATCHER_SCOPE(EliminateSplit);
    auto split_pattern = pattern::wrap_type<opset8::Split>();

    matcher_pass_callback callback = [](pattern::Matcher& m) {
        const auto split = ov::as_type_ptr<opset8::Split>(m.get_match_root());
        const bool single_chunk = split != nullptr && split->get_num_splits() == 1;
        return single_chunk && replace_output_update_name(split->output(0), split->input_value(0));
    };

    auto m = make_shared<pattern::Matcher>(split_pattern, matcher_name);
    this->register_matcher(m, callback);
}

// Simplifies a Squeeze:
//  * static, non-scalar output rank with at most one dynamic dimension -> replaced by a Reshape
//    (see replace_squeeze_unsqueeze);
//  * Unsqueeze->Squeeze or Squeeze->Squeeze pairs -> folded into a single Squeeze/Unsqueeze when the
//    folded op reproduces the same output shape.
pass::EliminateSqueeze::EliminateSqueeze() {
    MATCHER_SCOPE(EliminateSqueeze);
    auto squeeze_pattern = pattern::wrap_type<opset8::Squeeze>();

    matcher_pass_callback callback = [](pattern::Matcher& m) {
        const auto node = m.get_match_root();
        auto out_shape = node->get_output_partial_shape(0);
        // try to replace all unsqueeze/squeeze with reshape
        if (out_shape.rank().is_static() && out_shape.rank().get_length() != 0 && count_unknown_dims(out_shape) < 2) {
            return replace_squeeze_unsqueeze(node);
        }

        auto squeeze = ov::as_type_ptr<opset3::Squeeze>(node);
        if (squeeze == nullptr)
            return false;
        auto input = squeeze->input_value(0).get_node_shared_ptr();
        // Replaces `squeeze` with one Squeeze applied to `input`'s data input, if the output shape is kept.
        auto replace_squeeze_only = [&](const vector<int64_t>& axes) {
            auto axes_const = opset3::Constant::create<int64_t>(element::i64, Shape{axes.size()}, axes);
            auto new_sq = make_shared<opset3::Squeeze>(input->input_value(0), axes_const);
            if (squeeze->get_output_partial_shape(0).same_scheme(new_sq->get_output_partial_shape(0))) {
                return replace_node_update_name(squeeze, new_sq);
            }
            return false;
        };
        // eliminate redundant unsqueeze->squeeze
        if (auto unsqueeze = ov::as_type_ptr<opset3::Unsqueeze>(input)) {
            // `input` is the Unsqueeze itself, so this is the shape before both ops.
            const PartialShape data_shape = unsqueeze->get_input_partial_shape(0);
            // A Squeeze may be created without an axes input, in which case input 1 is out of range:
            // only compare the axes constants when this Squeeze actually has them.
            if (squeeze->get_input_size() > 1 &&
                ngraph::compare_constants(unsqueeze->input_value(1).get_node_shared_ptr(),
                                          squeeze->input_value(1).get_node_shared_ptr())) {
                return replace_output_update_name(squeeze->output(0), unsqueeze->input_value(0));
            }
            if (data_shape.rank().is_dynamic() || out_shape.rank().is_dynamic()) {
                return false;
            }
            if (out_shape.rank().get_length() < data_shape.rank().get_length()) {
                // check if single squeeze can handle this
                auto axes = get_squeeze_axes(data_shape, out_shape);
                if (data_shape.rank().get_length() ==
                    out_shape.rank().get_length() + static_cast<int64_t>(axes.size())) {
                    return replace_squeeze_only(axes);
                }
            }
            if (out_shape.rank().get_length() > data_shape.rank().get_length()) {
                // check if single unsqueeze can handle this
                auto axes = get_unsqueeze_axes(data_shape, out_shape);
                if (data_shape.rank().get_length() + static_cast<int64_t>(axes.size()) ==
                    out_shape.rank().get_length()) {
                    auto axes_const = opset3::Constant::create<int64_t>(element::i64, Shape{axes.size()}, axes);
                    auto new_unsq = make_shared<opset3::Unsqueeze>(input->input_value(0), axes_const);
                    if (squeeze->get_output_partial_shape(0).same_scheme(new_unsq->get_output_partial_shape(0))) {
                        // Swap in the new node (friendly name + runtime info are transferred) and report
                        // whether the replacement really happened instead of unconditionally claiming success.
                        return replace_node_update_name(squeeze, new_unsq);
                    }
                }
            }
            return false;
        }
        // eliminate redundant squeeze->squeeze
        if (auto squeeze_i = ov::as_type_ptr<opset3::Squeeze>(input)) {
            // `input` is the inner Squeeze itself, so this is the shape before both ops.
            const PartialShape data_shape = squeeze_i->get_input_partial_shape(0);
            if (data_shape.rank().is_dynamic() || out_shape.rank().is_dynamic()) {
                return false;
            }
            auto axes = get_squeeze_axes(data_shape, out_shape);
            return replace_squeeze_only(axes);
        }
        return false;
    };

    auto m = make_shared<pattern::Matcher>(squeeze_pattern, matcher_name);
    this->register_matcher(m, callback);
}

namespace {
// Normalizes a possibly negative axis `value` against the rank of `node`.
// The value is returned unchanged when it is already non-negative or the rank is dynamic.
int64_t make_positive(int64_t value, const Output<Node>& node) {
    const auto& rank = node.get_partial_shape().rank();
    if (value < 0 && rank.is_static()) {
        value = rank.get_length() + value;
    }
    return value;
}

// Checks that `node` is a Squeeze whose constant axes input holds exactly one axis equal to 1
// (a negative axis is normalized against the Squeeze input rank).
bool check_squeeze(const shared_ptr<Node>& node) {
    auto squeeze = dynamic_pointer_cast<ov::opset9::Squeeze>(node);
    // a Squeeze without an axes input has no input 1 to inspect
    if (!squeeze || squeeze->get_input_size() < 2) {
        return false;
    }
    auto axis = dynamic_pointer_cast<ov::opset9::Constant>(squeeze->input_value(1).get_node_shared_ptr());
    if (!axis) {
        return false;
    }
    const auto axis_val = axis->cast_vector<int64_t>();
    return axis_val.size() == 1 && make_positive(axis_val[0], squeeze->input_value(0)) == 1;
}

// Checks that Reshape actually equals to Squeeze op over axis 1: its input shape must be static and equal
// to the constant shape pattern with a 1 inserted at position 1.
// 0 and -1 values in the shape pattern are not allowed, since their meaning depends on the input shape.
bool check_reshape(const shared_ptr<Node>& node) {
    auto reshape = dynamic_pointer_cast<ov::opset9::Reshape>(node);
    if (!reshape) {
        return false;
    }
    auto shape_pattern = dynamic_pointer_cast<ov::opset9::Constant>(reshape->input_value(1).get_node_shared_ptr());
    if (!shape_pattern) {
        return false;
    }
    auto pattern_val = shape_pattern->cast_vector<int64_t>();
    const bool has_zero = find(pattern_val.begin(), pattern_val.end(), 0) != pattern_val.end();
    const bool has_minus_one = find(pattern_val.begin(), pattern_val.end(), -1) != pattern_val.end();
    // Both special values must be absent. An empty pattern (reshape to a scalar) cannot be a Squeeze of
    // axis 1 either, and inserting at begin() + 1 would be out of range for it.
    if (has_zero || has_minus_one || pattern_val.empty()) {
        return false;
    }
    pattern_val.insert(pattern_val.begin() + 1, 1);
    const auto& in_shape = reshape->input_value(0).get_partial_shape();
    // Current Reshape is a product of eliminate_reshape_v1 transformation.
    // Initial Unsqueeze operation had static input shape and thus was replaced.
    // This makes us eligible to assume input shape of Reshape that we are searching for is static
    return in_shape.is_static() && in_shape == pattern_val;
}

// Checks that `split` splits along the same axis `concat` concatenates (comparing normalized axes).
// With is_special_case (the Sequence pattern described in check_all_inputs) it additionally requires the
// last Split output to have a static length of 1 along that axis.
bool check_axis(const shared_ptr<ov::opset9::Concat>& concat,
                const shared_ptr<Node>& split,
                bool is_special_case = false) {
    auto axis = dynamic_pointer_cast<ov::opset9::Constant>(split->input_value(1).get_node_shared_ptr());
    if (!axis) {
        return false;
    }
    const auto& axis_val = axis->cast_vector<int64_t>();
    if (axis_val.size() != 1 || (axis_val[0] != concat->get_axis() && make_positive(axis_val[0], split->output(0)) !=
                                                                          make_positive(concat->get_axis(), concat))) {
        return false;
    }

    // in case of LSTM/GRU/RNN Sequence case described below and Split/VariadicSplit op,
    // we have to check that the last slice length equals 1,
    // it corresponds output(1) of Seq op
    if (is_special_case) {
        const auto last_out = split->output(split->get_output_size() - 1);
        const auto& last_out_shape = last_out.get_partial_shape();
        if (last_out_shape.rank().is_dynamic()) {
            return false;
        }
        // the split axis may be negative: normalize it before using it as an index
        const auto split_axis = make_positive(axis_val[0], last_out);
        if (split_axis < 0 || split_axis >= last_out_shape.rank().get_length()) {
            return false;
        }
        const auto& dim = last_out_shape[split_axis];
        if (dim.is_dynamic() || dim.get_length() != 1) {
            return false;
        }
    }
    return true;
}

// Returns the single Split/VariadicSplit (type T) whose outputs feed every Concat input, in order and along
// the same axis, so that the Concat reassembles the Split input. Returns nullptr otherwise.
template <class T>
shared_ptr<T> check_all_inputs(const shared_ptr<ov::opset9::Concat>& concat) {
    shared_ptr<T> split;
    const auto concat_in_values = concat->input_values();
    size_t idx = 0;
    for (const auto& in_to_concat : concat_in_values) {
        const auto& cast_to_split = dynamic_pointer_cast<T>(in_to_concat.get_node_shared_ptr());
        // There is a special case with (GRU/RNN/LSTM)Sequence ops:
        //
        // (LSTM/GRU/RNN)Sequence -- output(0) --> Squeeze (Reshape) ->Split -(H1...Hn-1 outs) ->  Concat
        //                        -- output(1) Hn out ------------------------------------------>
        //
        // Sequence->output(0) is a concatenation of H1 ... Hn from each iteration
        // Sequence->output(1) is a Hn from the last iteration
        // where n is a number of iterations
        //
        // If we found Sequence->output(0) is split into separate H1...Hn but only H1...Hn-1 are used
        // for Concat and the last input to Concat is output(1) of Sequence op, which is actually Hn,
        // this is also a valid case for this Elimination.
        if (!cast_to_split) {
            if (idx != (concat_in_values.size() - 1) || !split) {
                return {};
            }
            // All Split outputs except the last one (Hn) must already have been consumed, in order;
            // otherwise the Concat only covers part of the Split input.
            if (idx + 1 != split->get_output_size()) {
                return {};
            }
            shared_ptr<Node> in_to_split = split->input_value(0).get_node_shared_ptr();
            Output<Node> seq_out;
            if (in_to_split && !in_to_split->inputs().empty()) {
                seq_out = in_to_split->input_value(0);
            } else {
                return {};
            }

            auto seq_node = seq_out.get_node_shared_ptr();
            if (!seq_node || seq_out.get_index() != 0 ||
                !(dynamic_pointer_cast<ov::opset9::RNNSequence>(seq_node) ||
                  dynamic_pointer_cast<ov::opset9::GRUSequence>(seq_node) ||
                  dynamic_pointer_cast<ov::opset9::LSTMSequence>(seq_node))) {
                return {};
            }

            // check that Split is connected to Sequence->output(0)
            // possible patterns:
            // Sequence:0->Squeeze->Split
            bool valid_pattern = check_squeeze(in_to_split);
            // Sequence:0->Reshape->Split
            if (!valid_pattern) {
                valid_pattern = check_reshape(in_to_split);
            }

            if (!valid_pattern) {
                return {};
            }

            // check that Sequence->output(1) is connected to this input and concat/split axis is the same.
            if (in_to_concat != seq_node->output(1) || !check_axis(concat, split, true)) {
                return {};
            }
            return split;
        }
        // input (split op) should be the same for all inputs
        if (!split) {
            split = cast_to_split;
        } else if (cast_to_split.get() != split.get()) {
            // not all inputs to concat belong to the same Split op
            return {};
        }

        // Split to Concat edges are not in order
        // should be (0, 1, 2, ... , split->outputs().size()-1)
        if (in_to_concat.get_index() != idx) {
            return {};
        }
        ++idx;
    }

    // no Split found, not all split outputs are used or concat/split axis is not the same.
    if (!split || idx != split->outputs().size() || !check_axis(concat, split)) {
        return {};
    }

    return split;
}
}  // namespace

// Bypasses a Concat that reassembles, in order and along the same axis, all outputs of one
// Split/VariadicSplit: the Concat output is replaced by the Split's data input.
ov::pass::EliminateSplitConcat::EliminateSplitConcat() {
    MATCHER_SCOPE(EliminateSplitConcat);

    auto concat_root = pattern::wrap_type<opset8::Concat>();
    matcher_pass_callback callback = [=](pattern::Matcher& m) {
        const auto concat = ov::as_type_ptr<ov::opset9::Concat>(m.get_pattern_map().at(concat_root));
        if (concat == nullptr) {
            return false;
        }

        shared_ptr<Node> producer = check_all_inputs<ov::opset9::Split>(concat);
        if (producer == nullptr) {
            producer = check_all_inputs<ov::opset9::VariadicSplit>(concat);
        }
        return producer != nullptr && replace_output_update_name(concat->output(0), producer->input_value(0));
    };

    auto m = make_shared<pattern::Matcher>(concat_root, matcher_name);
    this->register_matcher(m, callback);
}

// Bypasses a Transpose whose constant permutation is the identity (0, 1, ..., rank-1).
pass::EliminateTranspose::EliminateTranspose() {
    MATCHER_SCOPE(EliminateTranspose);
    auto order_pattern = pattern::wrap_type<opset8::Constant>();
    auto transpose_pattern = pattern::wrap_type<opset8::Transpose>({pattern::any_input(), order_pattern});

    matcher_pass_callback callback = [=](pattern::Matcher& m) {
        const auto order_const = ov::as_type_ptr<opset8::Constant>(m.get_pattern_map().at(order_pattern));
        if (order_const == nullptr) {
            return false;
        }

        const auto permutation = order_const->cast_vector<int64_t>();
        for (size_t pos = 0; pos < permutation.size(); ++pos) {
            if (permutation[pos] != static_cast<int64_t>(pos)) {
                return false;
            }
        }

        const auto transpose = m.get_match_root();
        return replace_output_update_name(transpose->output(0), transpose->input_value(0));
    };

    auto m = make_shared<pattern::Matcher>(transpose_pattern, matcher_name);
    this->register_matcher(m, callback);
}

// Bypasses Add/Subtract/Multiply/Divide with a constant operand (or Subtract of a Convert-ed constant)
// when op::util::can_eliminate_eltwise_node decides the operation leaves the other input unchanged.
pass::EliminateEltwise::EliminateEltwise() {
    MATCHER_SCOPE(EliminateEltwise);
    auto data_pattern = pattern::any_input();
    auto const_pattern = pattern::wrap_type<opset8::Constant>();
    auto binary_pattern = pattern::wrap_type<opset8::Add, opset8::Subtract, opset8::Multiply, opset8::Divide>(
        {data_pattern, const_pattern});
    auto sub_of_convert_pattern =
        pattern::wrap_type<opset8::Subtract>({data_pattern, pattern::wrap_type<opset8::Convert>({const_pattern})});
    auto root = make_shared<pattern::op::Or>(OutputVector{binary_pattern, sub_of_convert_pattern});

    matcher_pass_callback callback = [=](pattern::Matcher& m) {
        const auto& value_map = m.get_pattern_value_map();
        const auto eltwise = m.get_match_root();
        const auto& data = value_map.at(data_pattern);
        const auto& operand = value_map.at(const_pattern);

        const bool removable = op::util::can_eliminate_eltwise_node(eltwise, operand, data);
        return removable && replace_output_update_name(eltwise->output(0), data);
    };

    auto m = make_shared<pattern::Matcher>(root, matcher_name);
    this->register_matcher(m, callback);
}

// Bypasses a ScatterUpdate / ScatterNDUpdate / ScatterElementsUpdate that is statically known to write
// no elements, i.e. whose updates (or number of index entries) contain a dimension equal to 0.
pass::EliminateScatterUpdate::EliminateScatterUpdate() {
    MATCHER_SCOPE(EliminateScatterUpdate);
    auto scatter_pattern =
        pattern::wrap_type<opset8::ScatterUpdate, opset8::ScatterNDUpdate, opset8::ScatterElementsUpdate>();

    matcher_pass_callback callback = [=](pattern::Matcher& m) {
        auto scatter = m.get_match_root();
        const auto& indices_pshape = scatter->get_input_partial_shape(1);
        const auto& updates_pshape = scatter->get_input_partial_shape(2);

        // true if any dimension in [first, last) is statically known to be 0
        auto has_zero = [](auto first, auto last) -> bool {
            return std::any_of(first, last, ov::cmp::Equal<ov::Dimension>(0));
        };

        bool writes_nothing = has_zero(updates_pshape.cbegin(), updates_pshape.cend());
        if (!writes_nothing) {
            if (ov::is_type<opset8::ScatterNDUpdate>(scatter)) {
                // For ScatterNDUpdate the innermost indices dimension is the length of each index tuple, not a
                // count of updates: 0 there means every (empty) index addresses the whole data tensor. Only the
                // outer indices dimensions determine how many updates are applied.
                if (indices_pshape.rank().is_static() && indices_pshape.rank().get_length() > 0) {
                    writes_nothing = has_zero(indices_pshape.cbegin(), indices_pshape.cend() - 1);
                }
            } else {
                writes_nothing = has_zero(indices_pshape.cbegin(), indices_pshape.cend());
            }
        }

        if (!writes_nothing) {
            return false;
        }
        return replace_output_update_name(scatter->output(0), scatter->input_value(0));
    };

    auto m = make_shared<pattern::Matcher>(scatter_pattern, matcher_name);
    this->register_matcher(m, callback);
}

// Aggregates the individual no-op elimination matcher passes defined in this file, added in the order below.
// use_shape_for_elimination additionally enables the passes whose decisions depend on shape information.
ov::pass::NopElimination::NopElimination(bool use_shape_for_elimination) {
    // shape-agnostic transformations
    ADD_MATCHER_FOR_THIS(EliminatePad)
    ADD_MATCHER_FOR_THIS(EliminateConvert)
    ADD_MATCHER_FOR_THIS(EliminateConvertNonZero)
    ADD_MATCHER_FOR_THIS(EliminateConcat)
    ADD_MATCHER_FOR_THIS(EliminateSplit)
    ADD_MATCHER_FOR_THIS(EliminateTranspose)
    ADD_MATCHER_FOR_THIS(EliminateEltwise)
    // NOTE(review): this using-directive looks redundant inside an ov::pass member function — confirm before removing.
    using namespace ov::pass;
    ADD_MATCHER_FOR_THIS(EliminateSplitConcat)

    // shape-dependent transformations
    if (use_shape_for_elimination) {
        ADD_MATCHER_FOR_THIS(EliminateScatterUpdate)
        ADD_MATCHER_FOR_THIS(EliminateReshape)
        ADD_MATCHER_FOR_THIS(EliminateSqueeze)
        ADD_MATCHER_FOR_THIS(EliminateUnsqueeze)
        ADD_MATCHER_FOR_THIS(EliminateBroadcast)
        ADD_MATCHER_FOR_THIS(EliminateGather)
    }
}
